{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#Pre and post processing\n",
    "在某些情形下，需要对Index做前处理或后处理。\n",
    "\n",
    "##ID映射\n",
    "默认情况下，faiss会为每个输入的向量记录一个次序id，在使用中也可以为向量指定任意我们需要的id。  \n",
    "部分index类型有add_with_ids方法，可以为每个向量对应一个64-bit的id，搜索的时候返回这个指定的id。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[100000 100001 100002 ... 101997 101998 101999]\n",
      "2000\n"
     ]
    }
   ],
   "source": [
    "#导入faiss\n",
    "import sys\n",
    "sys.path.append('/home/maliqi/faiss/python/')\n",
    "import faiss\n",
    "import numpy as np \n",
    "\n",
    "#获取数据和Id\n",
    "d = 512\n",
    "n_data = 2000\n",
    "data = np.random.rand(n_data, d).astype('float32')\n",
    "ids = np.arange(100000, 102000)  #id设定为6位数整数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[100000 100383 101007 101444 100729]\n",
      " [100001 100880 101693 100004 100964]\n",
      " [100002 101113 101998 101017 101768]\n",
      " [100003 100694 101701 101608 100831]\n",
      " [100004 100111 100564 100541 100513]]\n"
     ]
    }
   ],
   "source": [
    "nlist = 10\n",
    "quantizer = faiss.IndexFlatIP(d)\n",
    "index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)\n",
    "index.train(data)\n",
    "index.add_with_ids(data, ids)\n",
    "d, i = index.search(data[:5], 5)\n",
    "print(i)  #返回的id应该是我们自己设定的"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "但是对有些Index类型，并不支持add_with_ids，因此需要与其他Index类型结合，将默认的id映射到指定id，用IndexIDMap类实现。  \n",
    "指定的ids不能是字符串，只能是整数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "Error in virtual void faiss::Index::add_with_ids(faiss::Index::idx_t, const float*, const long int*) at Index.cpp:46: add_with_ids not implemented for this type of index",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-5-4de928a09ab9>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mindex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfaiss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIndexFlatL2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mindex\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_with_ids\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mids\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m/home/maliqi/faiss/python/faiss/__init__.py\u001b[0m in \u001b[0;36mreplacement_add_with_ids\u001b[0;34m(self, x, ids)\u001b[0m\n\u001b[1;32m    104\u001b[0m         \u001b[0;32massert\u001b[0m \u001b[0md\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    105\u001b[0m         \u001b[0;32massert\u001b[0m \u001b[0mids\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'not same nb of vectors as ids'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 106\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_with_ids_c\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mswig_ptr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mswig_ptr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mids\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    107\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    108\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mreplacement_assign\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/home/maliqi/faiss/python/faiss/swigfaiss.py\u001b[0m in \u001b[0;36madd_with_ids\u001b[0;34m(self, n, x, xids)\u001b[0m\n\u001b[1;32m   1316\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1317\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0madd_with_ids\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxids\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1318\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0m_swigfaiss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIndex_add_with_ids\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxids\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1319\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1320\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0msearch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdistances\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: Error in virtual void faiss::Index::add_with_ids(faiss::Index::idx_t, const float*, const long int*) at Index.cpp:46: add_with_ids not implemented for this type of index"
     ]
    }
   ],
   "source": [
    "index = faiss.IndexFlatL2(data.shape[1]) \n",
    "index.add_with_ids(data, ids)  #报错"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "index2 = faiss.IndexIDMap(index)  \n",
    "index2.add_with_ids(data, ids)  #将index的id映射到index2的id,会维持一个映射表"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##数据转换\n",
    "有些时候需要在索引之前转换数据。转换类继承了VectorTransform类，将输入向量转换为输出向量。  \n",
    "1)随机旋转,类名RandomRotationMatri,用以均衡向量中的元素，一般在IndexPQ和IndexLSH之前；  \n",
    "2）PCA,类名PCAMatrix，降维；  \n",
    "3）改变维度，类名RemapDimensionsTransform，可以升高或降低向量维数\n",
    "\n",
    "###举例：PCA降维（通过IndexPreTransform）\n",
    "输入向量是2048维，需要减少到16byte。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = np.random.rand(n_data, 2048).astype('float32')\n",
    "# the IndexIVFPQ will be in 256D not 2048\n",
    "coarse_quantizer = faiss.IndexFlatL2 (256) \n",
    "sub_index = faiss.IndexIVFPQ (coarse_quantizer, 256, 16, 16, 8)\n",
    "# PCA 2048->256\n",
    "# 降维后随机旋转 (第四个参数)\n",
    "pca_matrix = faiss.PCAMatrix (2048, 256, 0, True) \n",
    "\n",
    "#- the wrapping index\n",
    "index = faiss.IndexPreTransform (pca_matrix, sub_index)\n",
    "\n",
    "# will also train the PCA\n",
    "index.train(data)  #数据需要是2048维\n",
    "# PCA will be applied prior to addition\n",
    "index.add(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###举例：升维\n",
    "有时候需要在向量中插入0升高维度，一般是我们需要 1）d是4的整数倍，有利于举例计算； 2）d是M的整数倍。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "d = 512\n",
    "M = 8   #M是在维度方向上分割的子空间个数\n",
    "d2 = int((d + M - 1) / M) * M\n",
    "remapper = faiss.RemapDimensionsTransform (d, d2, True)\n",
    "index_pq = faiss.IndexPQ(d2, M, 8)\n",
    "index = faiss.IndexPreTransform (remapper, index_pq) #后续可以添加数据/索引"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##对搜索结果重新排序\n",
    "当查询向量时，可以用真实距离值对结果进行重新排序。  \n",
    "在下面的例子中，搜索阶段会首先选取4*10个结果，然后对这些结果计算真实距离值，再从中选取10个结果返回。IndexRefineFlat保存了全部的向量信息，内存开销不容小觑。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[   0  434 1647 1501  867  658  822 1164  490 1430]\n",
      " [   1 1035  369  392  866 1645 1961 1469 1946  175]\n",
      " [   2  466 1183  403  427  505  394  759  633  746]\n",
      " [   3 1668 1798 1293  965 1484  755  315 1633 1453]\n",
      " [   4  309  715 1204  996  239 1381   48  707 1311]]\n"
     ]
    }
   ],
   "source": [
    "data = np.random.rand(n_data, d).astype('float32')\n",
    "nbits_per_index = 4\n",
    "q = faiss.IndexPQ (d, M, nbits_per_index)\n",
    "rq = faiss.IndexRefineFlat (q)\n",
    "rq.train (data)\n",
    "rq.add (data)\n",
    "rq.k_factor = 4\n",
    "dis, ind = rq.search (data[:5], 10)\n",
    "print(ind)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##综合多个index返回的结果\n",
    "当数据集分布在多个index中，需要在每个index中都执行搜索，然后使用IndexShards综合得到结果。同样也适用于index分布在不同的GPU的情况。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
